import collections
class Solution(object):
    """Solver for counting index triplets whose two sub-array XORs match."""

    def countTriplets(self, arr):
        """Count triplets (i, j, k) with 0 <= i < j <= k < len(arr) such that
        arr[i] ^ ... ^ arr[j-1] == arr[j] ^ ... ^ arr[k].

        The two XORs are equal exactly when arr[i..k] XORs to 0, which happens
        when two prefix-XOR values coincide. For such a pair, every split point
        j with i < j <= k is valid, contributing k - i triplets.

        Prefix positions are recorded as "index of the last element included",
        with -1 standing for the empty prefix. When the running XOR at `index`
        equals the value recorded at position `key`, the zero-XOR span is
        arr[key+1..index] and it contributes index - key - 1 triplets.

        Args:
            arr: sequence of ints.

        Returns:
            The number of qualifying triplets (0 for an empty or 1-element input).
        """
        ans = 0
        running = 0
        # For each prefix-XOR value: how many positions produced it, and the sum
        # of those positions. Summing (index - key - 1) over all earlier keys
        # equals count * (index - 1) - sum(keys), so no per-key rescan is
        # needed. This keeps the whole pass O(n) instead of O(n^2).
        counts = collections.defaultdict(int)
        key_sums = collections.defaultdict(int)
        # The empty prefix (XOR 0) sits at virtual position -1.
        counts[0] = 1
        key_sums[0] = -1
        for index, num in enumerate(arr):
            running ^= num
            ans += counts[running] * (index - 1) - key_sums[running]
            counts[running] += 1
            key_sums[running] += index
        return ans
